Vectorize discrimination_score; scale DE overlap/counts/AUC metrics to high perturbation counts#238
Conversation
There was a problem hiding this comment.
Code Review
This pull request vectorizes the discrimination_score metric, replacing the slow per-perturbation loop with an optimized matrix computation that includes closed-form column corrections for L1, L2, and cosine distances. It also adds a microbenchmark script and comprehensive equivalence tests. The reviewer suggested optimizing the ranking step by replacing the double argsort with a vectorized comparison (np.where) to find the diagonal element's rank, reducing complexity from
| # Rank of the matching perturbation within each row, by ascending distance. | ||
| # argsort(argsort(row)) is the inverse permutation, i.e. the rank of every | ||
| # column; the diagonal entry is the rank of the correct perturbation. | ||
| n_pert = data.perts.size | ||
| order = np.argsort(dist_matrix, axis=1) | ||
| ranks = np.argsort(order, axis=1)[np.arange(n_pert), np.arange(n_pert)] |
There was a problem hiding this comment.
The current implementation uses a double argsort (np.argsort(np.argsort(dist_matrix, axis=1))) to find the rank of the diagonal element (the matching perturbation) in each row. Sorting the entire order matrix a second time is computationally expensive ($O(N^2 \log N)$) and requires significant memory (
Since we only need the rank of the diagonal element order matrix using a vectorized comparison: np.where(order == np.arange(n_pert)[:, None])[1]. This avoids the second sorting step entirely, reducing the ranking overhead from
| # Rank of the matching perturbation within each row, by ascending distance. | |
| # argsort(argsort(row)) is the inverse permutation, i.e. the rank of every | |
| # column; the diagonal entry is the rank of the correct perturbation. | |
| n_pert = data.perts.size | |
| order = np.argsort(dist_matrix, axis=1) | |
| ranks = np.argsort(order, axis=1)[np.arange(n_pert), np.arange(n_pert)] | |
| # Rank of the matching perturbation within each row, by ascending distance. | |
| # Instead of a second argsort, we can find the column index of the diagonal | |
| # element in the sorted order matrix using a vectorized comparison. | |
| n_pert = data.perts.size | |
| order = np.argsort(dist_matrix, axis=1) | |
| ranks = np.where(order == np.arange(n_pert)[:, None])[1] |
There was a problem hiding this comment.
Good catch — adopted. The diagonal of argsort(argsort(.)) is just the position of column i within row i, which the boolean match finds directly, so the result is identical; I verified equality across random seeds and a heavy-ties matrix.
I benchmarked the claim on an Apple M2 Pro at n_pert=10000:
- ranking step: 9.8s → 5.7s (~1.7x), peak RSS 2.42 GB → 1.74 GB.
- The ~8x is accurate for the second
argsortintermediate specifically (an 800 MB int64 array replaced by a ~100 MB bool mask); end-to-end the ranking step is ~1.7x because the firstargsortis shared by both approaches.
Since the ranking runs once per metric, dropping the second sort helped the BLAS-bound metrics most end-to-end: l2@10k 51x → 70x, cosine@10k 61x → 79x. Benchmark table in the PR description updated.
discrimination_score previously looped over n_pert perturbations, calling pairwise_distances once per perturbation to compute a single row of an n_pert x n_pert distance matrix. This replaces the loop with a single full-matrix computation, then ranks each perturbation by locating its column's position in the per-row sorted order. The target-gene-exclusion path (the default for expression data) drops a different feature column per perturbation, so a single unmasked pairwise call cannot reproduce it. The full matrix is computed once and corrected per row with an exact, vectorized rank-1 update that removes the target gene's contribution (l1: subtract |delta|; l2: sqrt(d^2 - delta^2); cosine: drop the column from the dot product and both norms). Metrics without a closed-form column correction fall back to exact per-row masked distances, and duplicate gene names matching one perturbation are handled by an exact per-row net. Ranking uses a boolean match (np.where(order == arange)) to find each row's matching column, rather than a second argsort over the full matrix: identical result, but it avoids an extra O(n_pert^2) int64 array and a second sort. At n_pert=10000 the ranking step alone drops from ~9.8s / 2.4 GB to ~5.7s / 1.7 GB, which is a large share of the l2/cosine runtime. Output is numerically identical: across 216 synthetic configurations (metric x exclude x embed_key x seed x targeting-fraction) the normalized ranks match the original loop bit-for-bit (worst |delta-rank| = 0). Measured speedups (Apple M2 Pro, Python 3.12, numpy 2.4 / scipy 1.17 / scikit-learn 1.8; n_genes=2000; ranks identical to the loop at every point): n_pert l1 l2 cosine 100 4.1x 9.4x 15.6x 1000 3.6x 20.7x 39.3x 2000 5.3x 40.7x 61.0x 10000 8.6x 70.5x 79.2x l2/cosine use the BLAS dot-product trick, so a single matrix multiply replaces n_pert dispatched calls and the advantage grows with n_pert; l1 (manhattan) is a non-BLAS kernel, so its gain comes from removing per-call dispatch overhead. Memory is O(n_pert^2) for the full matrix, vs O(n_pert) per iteration before; this is the cost of vectorization and is comfortable for typical screen sizes. The cosine column correction clips masked squared norms at zero before the square root: an effect dominated by its target gene can round the masked norm slightly negative, which would otherwise yield NaN distances. Adds tests/test_discrimination_score.py (equivalence vs the original loop, covering the exclusion, exotic-metric, duplicate-gene, and target-gene- dominated paths) and benchmarks/bench_discrimination_score.py.
4bd7aff to
b1825e0
Compare
compute_overlap rebuilt the per-side rank matrix on every call: it invokes
get_top_genes(sort_by, fdr_threshold) for the real and pred sides, and that
builds a polars .pivot() with one column per perturbation. The de/full profile
registers 10 overlap variants (overlap/precision x k in {None,50,100,200,500}),
all with the same default sort_by and fdr_threshold, so the identical wide pivot
was rebuilt 10 times per side -- k only truncates the per-pert gene list
downstream, never the matrix itself.
Memoize get_top_genes on the DEResults instance keyed by (sort_by,
fdr_threshold) so the pivot is built once per side and reused across all
variants. The cache is a dataclass field excluded from init/repr/eq.
Also hoist the rank-matrix column names into sets once before the per-pert loop.
polars rebuilds a fresh column list on every .columns access, so the two
"pert not in matrix.columns" membership tests were O(n_perts) each, making the
loop O(n_perts^2); at ~18k perturbations this dominated the metric pass.
Both changes are pure performance: outputs are bit-identical. Adds
tests/test_de_overlap_equivalence.py asserting compute_overlap matches a
from-scratch reference across k and metric, that the cache collapses the 10
variants to one pivot per side, and that distinct (sort_by, fdr_threshold) keys
stay separate.
Mirrors bench_discrimination_score.py: keeps a verbatim copy of the pre-memoization get_top_genes / compute_overlap as the baseline, runs the full 10-variant overlap/precision pattern across a sweep of perturbation counts, and asserts old and new produce identical results. Measures the redundant-pivot + O(perts^2)-membership removal: speedup grows from ~2.2x at 1k perts to ~8.3x at 8k (n_sig=100), tracking the ~quadratic old path vs ~linear new path.
DENsigCounts looped over every perturbation calling get_significant_genes(pert) for the real and pred sides, and compute_generic_auc (pr/roc) looped calling merged.filter(target == pert). Each is a full-table scan per perturbation -- O(n_pert * n_rows) -- so at ~18k perturbations over a 371k-row DE table they dominate the metric pass (the sampler sat in get_significant_genes -> filter().collect() and in the per-pert filter for the whole sampling window). Replace both with a single slice: - DENsigCounts: one filter_to_significant().group_by(target).len() per side, reindexed over the full perturbation universe (0 for perts with no significant genes), since only the count is used. - compute_generic_auc: one merged.partition_by(target, maintain_order=True) before the loop; perts absent from the partition map -> nan, matching the old empty-slice branch. maintain_order keeps each partition in the row order the per-pert filter produced, so the labels/scores handed to average_precision_score / roc_curve are bit-identical. partition_by(as_dict) keys are tuples on newer polars and scalars on older, so normalize to str. Pure performance; output is bit-identical. The sibling DE metrics (DESpearmanSignificant/LFC, DEDirectionMatch, DESigGenesRecall) already use the group_by/join shape and are unchanged. Adds tests/test_de_perpert_scan_equivalence.py with verbatim pre-optimization references for both metrics.
Mirrors bench_de_overlap.py: verbatim pre-optimization DENsigCounts and compute_generic_auc baselines vs the current single-slice versions, swept over perturbation counts with outputs asserted identical. DENsigCounts (the metric the faulthandler sat in) drops 40-175x and grows with scale; the pr/roc AUC metrics get a steady ~1.8x because the per-pert sklearn calls and the shared merged build are irreducible -- only the O(perts x rows) table slicing was removed.
This PR contains three independent, bit-exact performance fixes to the DE /
anndata metric pipeline. All leave outputs numerically identical; none requires
re-baselining metrics. They were found in sequence: each fix cleared the
reigning bottleneck on a ~18k-perturbation screen and exposed the next.
discrimination_score(anndata-pair metric) — closes discrimination_score recomputes an n_pert x n_pert distance matrix one row at a time #237.compute_overlap/get_top_genes(DE metric).DENsigCountsandcompute_generic_auc(pr/roc) with a single grouped/partitioned slice(DE metrics).
Fix 1 — Vectorize
discrimination_score(closes #237)What
discrimination_scorelooped overn_pertperturbations, callingpairwise_distancesonce per perturbation to compute a single row of ann_pert x n_pertdistance matrix. This computes the full matrix in one calland ranks each perturbation by locating its column's position in the per-row
sorted order.
The target-gene-exclusion path (default for expression data) drops a different
feature column per perturbation, so a single unmasked call can't reproduce it.
The full matrix is computed once and corrected per row with an exact,
vectorized rank-1 update that removes the target gene's contribution:
|pred_g - real_g|sqrt(d^2 - (pred_g - real_g)^2)squared norms are clipped at 0 so a target-gene-dominated effect can't round
negative into a NaN)
Metrics without a closed-form correction fall back to exact per-row masked
distances, and duplicate gene names matching a perturbation are handled by an
exact per-row safety net.
Ranking uses
np.where(order == arange[:, None])to find each row's matchingcolumn rather than a second
argsortover the full matrix — identical result,but it avoids an extra
O(n_pert^2)int64 array and a second sort (thanks tothe review suggestion). At
n_pert=10000the ranking step alone drops from~9.8s / 2.4 GB to ~5.7s / 1.7 GB, a large share of the l2/cosine runtime.
Parity (no behavior change)
Output is numerically identical to the original loop. Across 216 synthetic
configurations (metric x exclude_target_gene x embed_key x seed x
targeting-fraction) the normalized ranks match bit-for-bit (worst
|delta-rank| = 0), and the equivalence is covered by the unit tests.Benchmark
Apple M2 Pro, Python 3.12.12, numpy 2.4.6 / scipy 1.17.1 / scikit-learn 1.8.0,
n_genes=2000(old = original per-perturbation loop, new = vectorized):(
n_pert=10000old times are from the same reference loop on the same machine;the loop is unchanged by this PR.)
l2/cosine use the BLAS dot-product trick, so one matrix multiply replaces
n_pertdispatched calls and the speedup grows withn_pert; l1 (manhattan)is a non-BLAS kernel, so its gain is the removed per-call dispatch overhead.
Trade-off: the full matrix is
O(n_pert^2)memory vsO(n_pert)per iterationbefore — comfortable for typical screen sizes.
Reproduce:
python benchmarks/bench_discrimination_score.pyFix 2 — DE overlap metric: memoize the rank-matrix pivot + O(1) column membership
cell_eval/_types/_de.py—DEComparison.compute_overlap/DEResults.get_top_genes.Motivation
An OmniPert
report_onlyvalidate-only run over a HuangChu context with~18,000 perturbations (16 CPU, 24h wall) ran for 8.8 hours and never
reached
results.csv. Faulthandler sampling of the live process (524 periodicstack dumps) put 513/524 ≈ 98% of all samples inside a single metric:
The DE compute itself (pdex) was already fast (~135s); this is purely the
metric pass, and it ate the entire budget before the rest of the DE metrics
and the anndata-pair metrics even started.
Two compounding inefficiencies
1. Redundant rank-matrix pivots (the big one).
compute_overlapcallsself.real.get_top_genes(sort_by, fdr_threshold)andself.pred.get_top_genes(...)at the top of every invocation.
get_top_genesbuilds a polars.pivot()withone column per perturbation (~18k-wide). The de/full profile registers 10
variants of this metric —
metrics/_impl.pydoesfor metric in ["overlap","precision"]: for n in [None,50,100,200,500]withkwargs={"k": n, "metric": metric}and the defaultsort_by/fdr_threshold.So the identical pair of 18k-column pivots is rebuilt 10 times;
konlyaffects the per-pert truncation
genes[:k_eff]inside the loop, never thematrix.
Fix: memoize
get_top_geneskeyed by(sort_by, fdr_threshold)on theper-side
DEResults(a dataclass field excluded frominit/repr/eq).DESortByis an enum (hashable), so the key is safe.2. O(perts²) column membership. The per-pert loop did
if pert not in real_sig_rank_matrix.columns or pert not in pred_sig_rank_matrix.columns:.polars rebuilds a fresh ~18k-element list on every
.columnsaccess, twiceper pert, over ~18k perts (~3×10⁸ list rebuilds — this is the
columnsleaf thesampler kept hitting).
Fix: precompute
real_cols = set(real_sig_rank_matrix.columns)andpred_cols = set(...)once before the loop and test against the sets. (Columnselection
matrix[pert]is unchanged.)Impact
The overlap pass drops from 10 redundant wide pivots + an O(perts²) membership
loop to 2 pivots + an O(perts) loop. What remains irreducible is one pivot
per side plus the per-pert intersect; everything the fix removes scales with
perturbation count.
Measured with
benchmarks/bench_de_overlap.py(Apple M2 Pro, polars 1.x; onefull 10-variant pass over synthetic real/pred DE tables, all outputs verified
bit-identical between old and new):
n_genes=2000,n_sig=100significant genes/pert:The old path grows ~quadratically and the new path ~linearly, so the speedup
climbs with
n_pert— the axis that matters for the ~18k-perturbation screensthat motivated this. With more significant genes per pert the absolute cost
rises (
n_sig=500: 8000 perts = 166.8s old vs 34.0s new), since the irreduciblesingle-pivot + intersect work grows while the removed redundancy stays fixed.
This is the metric that consumed an entire 8.8h validate-only budget at ~18k
perts (98% of samples in this exact call stack — see issue below); the measured
scaling above is the mechanism behind that, reproduced at tractable sizes.
Reproduce:
python benchmarks/bench_de_overlap.py --n-pert 1000 2000 4000 8000Parity (no behavior change)
Both changes are pure performance; outputs are bit-identical. New
tests/test_de_overlap_equivalence.py:compute_overlapmatches a from-scratch reference acrossk ∈ {None,1,2,50,500}and
metric ∈ {overlap, precision}(exact==on the result dicts).(
len(_top_genes_cache) == 1after all variants run).(sort_by, fdr_threshold)keys stay separate (3 keys → 3 entries).The existing end-to-end
tests/test_eval.py(which runs the de/full profilethrough the registry, exercising all 10 overlap variants) is unchanged and
still passes.
Fix 3 — DE per-perturbation full-table scans (
DENsigCounts, pr/roc AUC)cell_eval/metrics/_de.py—DENsigCounts.__call__/compute_generic_auc.Motivation
With Fix 2 in place, the same ~18k-perturbation OmniPert report run cleared the
overlap pass (the phase that previously stalled 8.6h now completes) and exposed
the next instance of the same per-perturbation-scan pathology downstream.
Faulthandler sampling of the live process (job over a 371k-row DE table) now sat
continuously in:
Two metrics share the shape — a
for pert in iter_perturbations()loop whosebody does a full-table
.filter(target == pert):DENsigCountscalledget_significant_genes(pert)for the real andpred side per perturbation — each a full scan of the whole DE table, so
~18k × 2 ≈ 36k scans, where only the significant-gene count is used.
compute_generic_auc(backing bothpr_aucandroc_auc) built itsmergedframe once (good) but then didmerged.filter(target == pert)perperturbation — ~18k more full scans, ×2 for the two AUC variants.
Fix
DENsigCounts: onefilter_to_significant().group_by(target).len()perside, then reindex over the full
iter_perturbations()universe filling0for perts with no significant genes (matching the old empty
.size). Only thecount is consumed downstream, so this is exact.
compute_generic_auc: onemerged.partition_by(target, maintain_order=True, as_dict=True)before the loop; iterate the dict. Perts absent from the map →nan(matching the oldshape[0] == 0branch).maintain_orderkeeps eachpartition in the exact row order the per-pert
.filterproduced, so thelabels/scoresarrays handed toaverage_precision_score/roc_curvearebit-identical — the per-pert sklearn calls are untouched, only the slicing
changed from O(perts × rows) to O(rows) total. (
partition_by(as_dict=True)keys are tuples on newer polars and scalars on older, so they're normalized.)
The audit covered the rest of
metrics/_de.py:DESpearmanSignificant,DESpearmanLFC,DEDirectionMatch, andDESigGenesRecallalready use a singlegroup_by/join(no per-pert loop) and are unchanged.Parity (no behavior change)
Pure performance; outputs bit-identical. New
tests/test_de_perpert_scan_equivalence.pyreproduces the pre-optimizationDENsigCountsandcompute_generic_aucverbatim and asserts the new codematches across a synthetic multi-pert
DEComparison— including theall-significant / all-non-significant perturbations that map to
nan, and thezero-significant reindex path.
Benchmark
benchmarks/bench_de_metrics.py(Apple M2 Pro, polars 1.41; verbatim pre-fixbaselines vs new, all outputs verified bit-identical),
n_genes=50(~theproduction DE table's rows-per-pert):
DENsigCounts— the metric the faulthandler actually sat in — was puretable-scan, so removing it gives 40–172x and the win grows with scale (old is
O(perts × rows), new O(rows)). The pr/roc AUC metrics get a steady ~1.8x: the
per-perturbation
sklearncalls and the sharedmergedbuild are irreducible,so only the O(perts × rows) slicing was removed (the absolute time saved still
grows with scale — 0.2s → 1.6s from 1k → 8k perts).
Reproduce:
python benchmarks/bench_de_metrics.py --n-pert 1000 2000 4000 8000Not addressed here (noted for scale)
The anndata-pair metric
clustering_agreementremains unprofiled at 18kperts — a separate potential cliff not yet reached.
discrimination_scoreisalready vectorized (Fix 1). Out of scope for this PR.
Tests / checks
tests/test_discrimination_score.py: equivalence vs the original loop,covering the exclusion, no-exclusion (embedding), exotic-metric fallback,
duplicate-gene safety net, and target-gene-dominated (degenerate cosine) paths.
tests/test_de_overlap_equivalence.py: bit-exact overlap/precisionequivalence + memoization guards (Fix 2).
tests/test_de_perpert_scan_equivalence.py: bit-exactDENsigCounts/pr_auc/roc_aucequivalence vs verbatim pre-optimization references (Fix 3).benchmarks/bench_discrimination_score.py,benchmarks/bench_de_overlap.py,and
benchmarks/bench_de_metrics.py: self-documenting old-vs-newmicrobenchmarks (each keeps a verbatim pre-optimization baseline and asserts
identical output).